#pragma once

#include <map>
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include <functional>
#include <string.h>
#include "model_state.h"
#include "ascend_util.h"
#include "triton/backend/backend_common.h"
#include "triton/backend/backend_model_instance.h"
#include "triton/backend/backend_input_collector.h"
#include "triton/backend/backend_output_responder.h"
#include "MxBase/E2eInfer/Model/Model.h"
#include "MxBase/E2eInfer/Tensor/Tensor.h"
#include "MxBase/DeviceManager/DeviceManager.h"

namespace triton { namespace backend { namespace ascend {

//
// ModelInstanceState
//
// State associated with a model instance. An object of this class is
// created and associated with each TRITONBACKEND_ModelInstance.
// It wraps an MxBase::Model (Ascend inference engine) and converts
// Triton request tensors to/from MxBase::Tensor for execution,
// including padding/cropping for dynamic-shape models.
//
class ModelInstanceState : public BackendModelInstance {
public:
    // Factory. Creates a ModelInstanceState bound to 'triton_model_instance'
    // and returns it through 'state'. Returns nullptr on success, or an
    // error object the caller owns. (Triton backends use this two-phase
    // pattern because constructors cannot report TRITONSERVER_Error.)
    static TRITONSERVER_Error* Create(ModelState* model_state,
        TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state);

    virtual ~ModelInstanceState();

    // Get the state of the model that corresponds to this instance.
    ModelState* StateForModel() const { return model_state_; }

    // Execute a batch of 'request_count' inference requests. Returns void:
    // per-request errors are presumably reported through the created
    // responses rather than a return value — confirm in the definition.
    void ProcessRequests(
        TRITONBACKEND_Request** requests, const uint32_t request_count);

private:
    ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance);

    // NOTE(review): "Vaildate" is a misspelling of "Validate" in the two
    // methods below; kept as-is because the definitions live in the
    // corresponding .cpp and renaming only the declarations would break it.
    // Check the model-config input declarations against the loaded model.
    TRITONSERVER_Error* VaildateInputs();

    // Check the model-config output declarations against the loaded model.
    TRITONSERVER_Error* VaildateOutputs();
    
    // Run the model on 'input_tensors', filling 'output_tensors'. On failure,
    // errors are presumably attached to 'responses' (one per request).
    void Execute(
        std::vector<TRITONBACKEND_Response*>* responses, const uint32_t response_count,
        std::vector<MxBase::Tensor>& input_tensors, std::vector<MxBase::Tensor>& output_tensors);

    // Gather request input buffers (via 'collector') into MxBase input
    // tensors for a batch of 'total_batch_size'. 'input_buffers' likely keeps
    // the collected host buffers alive for the duration of the call — confirm.
    TRITONSERVER_Error* SetInputTensors(
        size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count,
        BackendInputCollector* collector, std::vector<MxBase::Tensor>& input_tensors,
        std::vector<const char*>& input_buffers);

    // Copy model outputs back into per-request Triton responses,
    // splitting the batch dimension across the 'request_count' requests.
    TRITONSERVER_Error* ReadOutputTensors(
        size_t total_batch_size, const std::vector<MxBase::Tensor>& output_tensors,
        TRITONBACKEND_Request** requests, const uint32_t request_count,
        std::vector<TRITONBACKEND_Response*>* responses);
    
    // Compute the shapes inputs must be padded to before execution,
    // dispatching on the model's dynamic-shape mode (see dynamic_info_).
    TRITONSERVER_Error* GetPaddingShape(std::vector<std::vector<uint32_t>>& padding_shapes);

    // Padding-shape helper for models with a dynamic batch dimension.
    TRITONSERVER_Error* GetDynamicBatchPaddingShape(std::vector<uint32_t>& input_shape);

    // Padding-shape helper for models with dynamic height/width.
    TRITONSERVER_Error* GetDynamicHWPaddingShape(std::vector<uint32_t>& input_shape);
    
    // Padding-shape helper for models with fully dynamic dims (one shape per input).
    TRITONSERVER_Error* GetDynamicDimsPaddingShape(std::vector<std::vector<uint32_t>>& input_shapes);

    // True if the actual input shapes already match the padding shapes
    // (i.e. no padding/cropping round-trip is needed).
    bool IsSame(std::vector<std::vector<uint32_t>>& input_shapes,
        std::vector<std::vector<uint32_t>>& padding_shapes);

    // Pad 'input_tensor' into the larger 'padding_tensor'.
    TRITONSERVER_Error* TensorPadding(MxBase::Tensor& input_tensor, MxBase::Tensor& padding_tensor);
    
    // Inverse of TensorPadding: crop the padded result back to the real shape.
    TRITONSERVER_Error* TensorCrop(MxBase::Tensor& input_tensor, MxBase::Tensor& crop_tensor);
private:
    // Input/output tensor names — presumably read from the model config; confirm in .cpp.
    std::vector<std::string> model_input_names_ {};
    std::vector<std::string> model_output_names_ {};
    // Layout of the (single) input and of each output (e.g. NCHW/NHWC).
    DataFormat input_format_;
    std::vector<DataFormat> output_format_;
    // Owning ModelState for the parent TRITONBACKEND_Model (not owned here).
    ModelState* model_state_;

    // Whether the model config enables batching (max_batch_size > 0).
    bool supports_batching_;

    // Dynamic-shape configuration driving the GetDynamic*PaddingShape dispatch.
    DynamicInfo dynamic_info_;
    // Full path to the model file loaded into ascend_model_.
    // NOTE(review): a previous comment here said "TorchScript" — copied from
    // the PyTorch backend; this is the Ascend backend (likely an .om file).
    std::string model_path_;

    // The loaded Ascend inference model.
    std::shared_ptr<MxBase::Model> ascend_model_;

    // Ascend device (NPU) id this instance executes on.
    int32_t device_;

    // NOTE(review): the four lookup tables below are per-instance non-static
    // members, so every instance rebuilds identical maps; consider making
    // them static const (requires auditing the .cpp for non-const operator[]
    // use before changing). Mapping: model-config dtype string -> MxBase dtype.
    std::map<std::string, MxBase::TensorDType> ModelConfigDTypeToMxBase {
        {"TYPE_BOOL", MxBase::TensorDType::BOOL},
        {"TYPE_UINT8", MxBase::TensorDType::UINT8},
        {"TYPE_UINT16", MxBase::TensorDType::UINT16},
        {"TYPE_UINT32", MxBase::TensorDType::UINT32},
        {"TYPE_UINT64", MxBase::TensorDType::UINT64},
        {"TYPE_INT8", MxBase::TensorDType::INT8},
        {"TYPE_INT16", MxBase::TensorDType::INT16},
        {"TYPE_INT32", MxBase::TensorDType::INT32},
        {"TYPE_INT64", MxBase::TensorDType::INT64},
        {"TYPE_FP16", MxBase::TensorDType::FLOAT16},
        {"TYPE_FP32", MxBase::TensorDType::FLOAT32},
        {"TYPE_FP64", MxBase::TensorDType::DOUBLE64},
        {"TYPE_INVALID", MxBase::TensorDType::UNDEFINED}
    };

    // Inverse of ModelConfigDTypeToMxBase: MxBase dtype -> model-config string.
    std::map<MxBase::TensorDType, std::string> MxBaseDTypeToModelConfig {
        {MxBase::TensorDType::BOOL, "TYPE_BOOL"},
        {MxBase::TensorDType::UINT8, "TYPE_UINT8"},
        {MxBase::TensorDType::UINT16, "TYPE_UINT16"},
        {MxBase::TensorDType::UINT32, "TYPE_UINT32"},
        {MxBase::TensorDType::UINT64, "TYPE_UINT64"},
        {MxBase::TensorDType::INT8, "TYPE_INT8"},
        {MxBase::TensorDType::INT16, "TYPE_INT16"},
        {MxBase::TensorDType::INT32, "TYPE_INT32"},
        {MxBase::TensorDType::INT64, "TYPE_INT64"},
        {MxBase::TensorDType::FLOAT16, "TYPE_FP16"},
        {MxBase::TensorDType::FLOAT32, "TYPE_FP32"},
        {MxBase::TensorDType::DOUBLE64, "TYPE_FP64"},
        {MxBase::TensorDType::UNDEFINED, "TYPE_INVALID"}
    };

    // Triton wire dtype -> MxBase dtype.
    // NOTE(review): "Triron" is a typo for "Triton"; kept for the same
    // .cpp-compatibility reason as the Vaildate* methods above.
    std::map<TRITONSERVER_DataType, MxBase::TensorDType> TrironDTypeToMxBase {
        {TRITONSERVER_TYPE_BOOL, MxBase::TensorDType::BOOL},
        {TRITONSERVER_TYPE_UINT8, MxBase::TensorDType::UINT8},
        {TRITONSERVER_TYPE_UINT16, MxBase::TensorDType::UINT16},
        {TRITONSERVER_TYPE_UINT32, MxBase::TensorDType::UINT32},
        {TRITONSERVER_TYPE_UINT64, MxBase::TensorDType::UINT64},
        {TRITONSERVER_TYPE_INT8, MxBase::TensorDType::INT8},
        {TRITONSERVER_TYPE_INT16, MxBase::TensorDType::INT16},
        {TRITONSERVER_TYPE_INT32, MxBase::TensorDType::INT32},
        {TRITONSERVER_TYPE_INT64, MxBase::TensorDType::INT64},
        {TRITONSERVER_TYPE_FP16, MxBase::TensorDType::FLOAT16},
        {TRITONSERVER_TYPE_FP32, MxBase::TensorDType::FLOAT32},
        {TRITONSERVER_TYPE_FP64, MxBase::TensorDType::DOUBLE64},
        {TRITONSERVER_TYPE_INVALID, MxBase::TensorDType::UNDEFINED}
    };

    // Inverse of TrironDTypeToMxBase: MxBase dtype -> Triton wire dtype.
    std::map<MxBase::TensorDType, TRITONSERVER_DataType> MxBaseDTypeToTriton {
        {MxBase::TensorDType::BOOL, TRITONSERVER_TYPE_BOOL},
        {MxBase::TensorDType::UINT8, TRITONSERVER_TYPE_UINT8},
        {MxBase::TensorDType::UINT16, TRITONSERVER_TYPE_UINT16},
        {MxBase::TensorDType::UINT32, TRITONSERVER_TYPE_UINT32},
        {MxBase::TensorDType::UINT64, TRITONSERVER_TYPE_UINT64},
        {MxBase::TensorDType::INT8, TRITONSERVER_TYPE_INT8},
        {MxBase::TensorDType::INT16, TRITONSERVER_TYPE_INT16},
        {MxBase::TensorDType::INT32, TRITONSERVER_TYPE_INT32},
        {MxBase::TensorDType::INT64, TRITONSERVER_TYPE_INT64},
        {MxBase::TensorDType::FLOAT16, TRITONSERVER_TYPE_FP16},
        {MxBase::TensorDType::FLOAT32, TRITONSERVER_TYPE_FP32},
        {MxBase::TensorDType::DOUBLE64, TRITONSERVER_TYPE_FP64},
        {MxBase::TensorDType::UNDEFINED, TRITONSERVER_TYPE_INVALID}
    };
};

}}};